#!/bin/bash
# Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Model repository used in single-model endpoint mode.
SAGEMAKER_SINGLE_MODEL_REPO=/opt/ml/model/

# Note: in Triton on SageMaker, each model url is registered as a separate repository
# e.g., /opt/ml/models/<hash>/model. Specifying MME model repo path as /opt/ml/models causes Triton
# to treat it as an additional empty repository and changes
# the state of all models to be UNAVAILABLE in the model repository
# https://github.com/triton-inference-server/core/blob/main/src/model_repository_manager.cc#L914,L922
# On Triton, this path will be a dummy path as it's mandatory to specify a model repo when starting triton
SAGEMAKER_MULTI_MODEL_REPO=/tmp/sagemaker

# Start from the single-model settings; they are replaced below only when
# SAGEMAKER_MULTI_MODEL is exactly "true" (any other value keeps single-model mode).
SAGEMAKER_MODEL_REPO=${SAGEMAKER_SINGLE_MODEL_REPO}
is_mme_mode=false
default_ping_mode="ready"

if [ "${SAGEMAKER_MULTI_MODEL}" == "true" ]; then
    mkdir -p "${SAGEMAKER_MULTI_MODEL_REPO}"
    SAGEMAKER_MODEL_REPO=${SAGEMAKER_MULTI_MODEL_REPO}
    is_mme_mode=true
    default_ping_mode="live"
fi

# Use 'ready' for ping check in single-model endpoint mode, and use 'live' for ping check in multi-model endpoint mode
# https://github.com/kserve/kserve/blob/master/docs/predict-api/v2/rest_predict_v2.yaml#L10-L26
# A non-empty SAGEMAKER_TRITON_OVERRIDE_PING_MODE takes precedence in both modes.
SAGEMAKER_TRITON_PING_MODE=${SAGEMAKER_TRITON_OVERRIDE_PING_MODE:-${default_ping_mode}}
# A plain assignment is not inherited by child processes unless the variable was
# already exported by the caller, so export it explicitly.
# NOTE(review): assumes tritonserver reads SAGEMAKER_TRITON_PING_MODE from its
# environment - confirm against the server sources.
export SAGEMAKER_TRITON_PING_MODE

if [ "${is_mme_mode}" = true ]; then
    echo -e "Triton is running in SageMaker MME mode. Using Triton ping mode: \"${SAGEMAKER_TRITON_PING_MODE}\""
fi

# Assemble the tritonserver options as one space-separated string.
# Two parameter-expansion forms are used below:
#   ${VAR:-default}       option is always emitted; a non-empty VAR replaces the default
#   ${VAR:+ --flag=...}   option is emitted only when VAR is set and non-empty
SAGEMAKER_ARGS="--model-repository=${SAGEMAKER_MODEL_REPO}"

# Model namespacing defaults to true. The value of
# SAGEMAKER_TRITON_DISABLE_MODEL_NAMESPACING is forwarded verbatim as the flag
# value, so it is the value "false" that actually switches namespacing off.
SAGEMAKER_ARGS+=" --model-namespacing=${SAGEMAKER_TRITON_DISABLE_MODEL_NAMESPACING:-true}"

# Optional SageMaker port settings.
SAGEMAKER_ARGS+="${SAGEMAKER_BIND_TO_PORT:+ --sagemaker-port=${SAGEMAKER_BIND_TO_PORT}}"
SAGEMAKER_ARGS+="${SAGEMAKER_SAFE_PORT_RANGE:+ --sagemaker-safe-port-range=${SAGEMAKER_SAFE_PORT_RANGE}}"

# gRPC and metrics endpoints stay disabled unless the environment says otherwise.
SAGEMAKER_ARGS+=" --allow-grpc=${SAGEMAKER_TRITON_ALLOW_GRPC:-false}"
SAGEMAKER_ARGS+=" --allow-metrics=${SAGEMAKER_TRITON_ALLOW_METRICS:-false}"
SAGEMAKER_ARGS+="${SAGEMAKER_TRITON_METRICS_PORT:+ --metrics-port=${SAGEMAKER_TRITON_METRICS_PORT}}"
SAGEMAKER_ARGS+="${SAGEMAKER_TRITON_GRPC_PORT:+ --grpc-port=${SAGEMAKER_TRITON_GRPC_PORT}}"

# Optional thread-count tuning.
SAGEMAKER_ARGS+="${SAGEMAKER_TRITON_BUFFER_MANAGER_THREAD_COUNT:+ --buffer-manager-thread-count=${SAGEMAKER_TRITON_BUFFER_MANAGER_THREAD_COUNT}}"
SAGEMAKER_ARGS+="${SAGEMAKER_TRITON_THREAD_COUNT:+ --sagemaker-thread-count=${SAGEMAKER_TRITON_THREAD_COUNT}}"

# Verbose logging is on by default; the remaining log levels are only passed when set.
SAGEMAKER_ARGS+=" --log-verbose=${SAGEMAKER_TRITON_LOG_VERBOSE:-true}"
SAGEMAKER_ARGS+="${SAGEMAKER_TRITON_LOG_INFO:+ --log-info=${SAGEMAKER_TRITON_LOG_INFO}}"
SAGEMAKER_ARGS+="${SAGEMAKER_TRITON_LOG_WARNING:+ --log-warning=${SAGEMAKER_TRITON_LOG_WARNING}}"
SAGEMAKER_ARGS+="${SAGEMAKER_TRITON_LOG_ERROR:+ --log-error=${SAGEMAKER_TRITON_LOG_ERROR}}"

# Python backend shared-memory sizing: 16MB initial size, 1MB growth step by default.
SAGEMAKER_ARGS+=" --backend-config=python,shm-default-byte-size=${SAGEMAKER_TRITON_SHM_DEFAULT_BYTE_SIZE:-16777216}"
SAGEMAKER_ARGS+=" --backend-config=python,shm-growth-byte-size=${SAGEMAKER_TRITON_SHM_GROWTH_BYTE_SIZE:-1048576}"
SAGEMAKER_ARGS+="${SAGEMAKER_TRITON_TENSORFLOW_VERSION:+ --backend-config=tensorflow,version=${SAGEMAKER_TRITON_TENSORFLOW_VERSION}}"

# Apply the same model-load limit to every GPU index; the GPU count is the
# number of lines printed by `nvidia-smi -L`.
if [[ -n "${SAGEMAKER_TRITON_MODEL_LOAD_GPU_LIMIT}" ]]; then
    num_gpus=$(nvidia-smi -L | wc -l)
    for ((gpu_idx = 0; gpu_idx < ${num_gpus}; gpu_idx++)); do
        SAGEMAKER_ARGS+=" --model-load-gpu-limit ${gpu_idx}:${SAGEMAKER_TRITON_MODEL_LOAD_GPU_LIMIT}"
    done
fi

# Free-form extra options are appended last.
SAGEMAKER_ARGS+="${SAGEMAKER_TRITON_ADDITIONAL_ARGS:+ ${SAGEMAKER_TRITON_ADDITIONAL_ARGS}}"


# Single-model mode only: a config.pbtxt sitting directly in the repository
# root means the model files were not placed under a top-level model folder,
# so refuse to start.
if [[ "${is_mme_mode}" == "false" && -f "${SAGEMAKER_MODEL_REPO}/config.pbtxt" ]]; then
    printf '%s\n' \
        "ERROR: Incorrect directory structure." \
        "       Model directory needs to contain the top level folder"
    exit 1
fi

# SAGEMAKER_TRITON_INFERENCE_TYPE is optional: unset or empty is accepted as-is,
# otherwise it has to be one of the known inference types or startup is aborted.
case "${SAGEMAKER_TRITON_INFERENCE_TYPE}" in
    "" | infer | generate | generate_stream)
        # Nothing to validate, or a recognised value.
        ;;
    *)
        printf '%s\n' \
            "ERROR: Invalid SAGEMAKER_TRITON_INFERENCE_TYPE '${SAGEMAKER_TRITON_INFERENCE_TYPE}'" \
            "       Must be one of: infer, generate, generate_stream"
        exit 1
        ;;
esac

# Fill the global array MODEL_DIRS with the names (not paths) of the immediate
# sub-directories of the directory given as $1.
# Names are read NUL-delimited so that a directory name containing whitespace,
# newlines or glob characters stays a single array element.
_sagemaker_collect_model_dirs() {
    local repo_dir=$1
    local dir_name
    MODEL_DIRS=()
    while IFS= read -r -d '' dir_name; do
        MODEL_DIRS+=("${dir_name}")
    done < <(find "${repo_dir}" -mindepth 1 -maxdepth 1 -type d -printf '%f\0')
}

# Single-model mode: decide which model is loaded at startup.
#  - If SAGEMAKER_TRITON_DEFAULT_MODEL_NAME is set, its directory must exist in the repository.
#  - Otherwise the repository must contain exactly one directory, which becomes
#    the default model and is exported for child processes.
# NOTE: SAGEMAKER_ARGS is a plain string that is word-split when the server is
# launched, so a model name containing whitespace still cannot be forwarded intact.
if [ "${is_mme_mode}" = false ]; then
    if [ -n "$SAGEMAKER_TRITON_DEFAULT_MODEL_NAME" ]; then
        if [ ! -d "${SAGEMAKER_MODEL_REPO}/$SAGEMAKER_TRITON_DEFAULT_MODEL_NAME" ]; then
            echo "ERROR: Directory with provided SAGEMAKER_TRITON_DEFAULT_MODEL_NAME ${SAGEMAKER_TRITON_DEFAULT_MODEL_NAME} does not exist"
            exit 1
        fi
    else
        _sagemaker_collect_model_dirs "${SAGEMAKER_MODEL_REPO}"
        case ${#MODEL_DIRS[@]} in
            0) echo "ERROR: No model found in model repository";
               exit 1
               ;;
            1) echo "WARNING: No SAGEMAKER_TRITON_DEFAULT_MODEL_NAME provided."
               echo "         Starting with the only existing model directory ${MODEL_DIRS[0]}";
               export SAGEMAKER_TRITON_DEFAULT_MODEL_NAME="${MODEL_DIRS[0]}"
               ;;
            *) echo "ERROR: More than 1 model directory found in model repository."
               echo "       Either provide a single directory or set SAGEMAKER_TRITON_DEFAULT_MODEL_NAME to run the ensemble backend."
               echo "       Directories found in model repository: ${MODEL_DIRS[*]}";
               exit 1
               ;;
        esac
    fi
    SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --load-model=${SAGEMAKER_TRITON_DEFAULT_MODEL_NAME}"
fi

# Replace this shell with tritonserver instead of running it as a child, so that
# signals delivered to this script's PID (e.g. a container stop signal) reach the
# server directly and its exit status becomes the script's exit status.
# SAGEMAKER_ARGS is deliberately left unquoted: it is a space-separated option
# string that must be word-split into individual arguments.
# shellcheck disable=SC2086
exec tritonserver --allow-sagemaker=true --allow-http=false --model-control-mode=explicit $SAGEMAKER_ARGS
